# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Pipelines for forward comparison."""

from ...components.expect_result_policy.cartesian_product_on_id_for_expect_result import IdCartesianProductERPC
from ...components.function_inputs_policy.cartesian_product_on_id_for_function_inputs import IdCartesianProductFIPC
from ...components.executor.exec_forward import IdentityEC
from ...components.function.run_block import RunBlockBC
from ...components.inputs.generate_inputs_from_shape import GenerateFromShapeDC
from ...components.verifier.compare_forward import CompareWithVC
from ...components.facade.me_facade import MeFacadeFC
from ...components.inputs.load_inputs_from_npy import LoadFromNpyDC
from ...components.verifier.verify_expect_from_npy import LoadFromNpyVC
from ...components.function.init_params_with_rand_and_run_block import RunBlockWithRandParamBC
from ...components.function_inputs_policy.cartesian_product_on_group_for_function_inputs import \
    GroupCartesianProductFIPC

# pylint: disable=W0105
"""
Compare an operator's result with a user-defined operator's result.
This pipeline is suitable for configs written in a case-by-case style.

Example:
    verification_set = [
        ('TensorAdd', {
            'block': (P.TensorAdd(), {'reduce_output': False}),
            'desc_inputs': [[1, 3, 3, 4], [1, 3, 3, 4]],
            'desc_bprop': [[1, 3, 3, 4]],
            'desc_expect': {
                'compare_with': [
                    (run_user_defined, user_defined.add)
                ],
            }
        })
    ]
"""
# Component roles below follow each component's package under `components`.
pipeline_for_compare_forward_with_user_defined_for_case_by_case_config = [
    MeFacadeFC,              # facade
    GenerateFromShapeDC,     # inputs: generated from shapes
    RunBlockBC,              # function: run the block
    IdCartesianProductFIPC,  # function-inputs policy: cartesian product on id
    IdentityEC,              # executor (exec_forward)
    IdCartesianProductERPC,  # expect-result policy: cartesian product on id
    CompareWithVC,           # verifier: compare forward results
]

"""
Compare an operator's result with the corresponding NumPy operator's result.
This pipeline is suitable for configs written in a case-by-case style.

Example:
    verification_set = [
        ('TensorAdd', {
            'block': (P.TensorAdd(), {'reduce_output': False}),
            'desc_inputs': [[1, 3, 3, 4], [1, 3, 3, 4]],
            'desc_bprop': [[1, 3, 3, 4]],
            'desc_expect': {
                'compare_with': [
                    (run_np, np.add)
                ]
            }
        })
    ]
"""
# Same component sequence as the user-defined comparison pipeline, kept as a
# separate list object so the two pipelines can be modified independently.
pipeline_for_compare_forward_with_numpy_for_case_by_case_config = [
    MeFacadeFC,              # facade
    GenerateFromShapeDC,     # inputs: generated from shapes
    RunBlockBC,              # function: run the block
    IdCartesianProductFIPC,  # function-inputs policy: cartesian product on id
    IdentityEC,              # executor (exec_forward)
    IdCartesianProductERPC,  # expect-result policy: cartesian product on id
    CompareWithVC,           # verifier: compare forward results
]

"""
Compare an operator's result with the expected result stored in an npy file.
This pipeline is suitable for configs written in a grouped style.

Example:
    verification_set = {
        'function': [
            {
                'id':'add',
                'group':'op-test',
                'block':(P.TensorAdd(), {'reduce_output': False}),
            }
        ],
        'inputs':  [
            {
                'id': 'add',
                'group': 'op-test',
                'desc_inputs': [
                    ('path/to/input/file_1.npy', {'dtype': np.float32}),
                    ('path/to/input/file_2.npy', {'dtype': np.float32}),
                ]
            }
        ],
        'expect': [
            {
                'id': 'add',
                'group': 'op-test',
                'desc_expect': [
                    ('path/to/expect/file.npy', {'dtype': np.float32})
                ]
            }
        ]
    }
"""
# Grouped-config pipeline: unlike the case-by-case pipelines it has no facade
# component. Roles follow each component's package under `components`.
pipeline_for_compare_forward_with_npy_for_group_by_group_config = [
    LoadFromNpyDC,            # inputs: loaded from npy files
    RunBlockWithRandParamBC,  # function: init params with rand, then run the block
    IdCartesianProductFIPC,   # function-inputs policy: cartesian product on id
    IdentityEC,               # executor (exec_forward)
    IdCartesianProductERPC,   # expect-result policy: cartesian product on id
    LoadFromNpyVC,            # verifier: verify against expectations from npy files
]

"""
Compare an operator's result with the expected result stored in an npy file. Test cases are generated by applying
each function to every input that has the same group id as that function.
This pipeline is suitable for configs written in a grouped style.

Example:
    verification_set = {
        'function': [
            {
                'id':'add',
                'group':'op-test',
                'block':(P.TensorAdd(), {'reduce_output': False}),
            }
        ],
        'inputs':  [
            {
                'id': 'add1',
                'group': 'op-test',
                'desc_inputs': [
                    ('path/to/input/file_1.npy', {'dtype': np.float32}),
                    ('path/to/input/file_2.npy', {'dtype': np.float32}),
                ]
            },
            {
                'id': 'add2',
                'group': 'op-test',
                'desc_inputs': [
                    ('path/to/input/file_3.npy', {'dtype': np.float32}),
                    ('path/to/input/file_4.npy', {'dtype': np.float32}),
                ]
            }
        ],
        'expect': [
            {
                'id': 'add1',
                'group': 'op-test',
                'desc_expect': [
                    ('path/to/expect/file_1.npy', {'dtype': np.float32})
                ]
            },
            {
                'id': 'add2',
                'group': 'op-test',
                'desc_expect': [
                    ('path/to/expect/file_2.npy', {'dtype': np.float32})
                ]
            }
        ]
    }
"""
# Variant of the grouped npy pipeline whose function-inputs policy pairs
# functions and inputs by group (GroupCartesianProductFIPC) instead of by id.
pipeline_for_compare_forward_with_npy_for_group_by_group_config_using_group_policy = [
    LoadFromNpyDC,               # inputs: loaded from npy files
    RunBlockWithRandParamBC,     # function: init params with rand, then run the block
    GroupCartesianProductFIPC,   # function-inputs policy: cartesian product on group
    IdentityEC,                  # executor (exec_forward)
    IdCartesianProductERPC,      # expect-result policy: cartesian product on id
    LoadFromNpyVC,               # verifier: verify against expectations from npy files
]
